{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "j4qMNV_533ls"
      },
      "source": [
        "##### Copyright 2025 Google LLC."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "urf-mQKk348O"
      },
      "outputs": [],
      "source": [
        "# @title Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "iS5fT77w4ZNj"
      },
      "source": [
        "# Gemma - Activation Hacking\n",
        "\n",
        "Author: Sascha Rothe\n",
        "\n",
        "This Colaboratory notebook is a tutorial on interacting with a Gemma\n",
        "checkpoint. It guides you through loading the model, generating samples, and\n",
        "examining its internal states, including the residual stream, MLP activations,\n",
        "and attention mechanisms. It also presents an example of modifying the model's\n",
        "behavior. The examples are illustrative; we encourage you to adapt these\n",
        "experiments and explore the model's behavior to uncover novel insights.\n",
        "\n",
        "**Note: This colab was tested with an A100 GPU. You might see unexpected behaviour on different hardware.**\n",
        "\n",
        "<table align=\"left\">\n",
        "  <td>\n",
        "    <a target=\"_blank\" href=\"https://colab.research.google.com/github/google-gemini/gemma-cookbook/blob/main/Gemma/[Gemma_3]Activation_Hacking.ipynb\"><img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a>\n",
        "  </td>\n",
        "</table>"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R4FrgYBWGdsJ"
      },
      "outputs": [],
      "source": [
        "#@title Install dependencies\n",
        "! pip install --no-deps -U flax\n",
        "! pip install jaxtyping kagglehub treescope"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "A0nmk8NwG4fK"
      },
      "source": [
        "To interact with the Gemma model, you will use the Flax NNX Gemma example code\n",
        "from the google/flax repository on GitHub. Because that code is not published as\n",
        "a package, the next cell clones the repository and temporarily adds\n",
        "`examples/gemma` to `sys.path` so that its modules can be imported."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-jpTUa1YESaM"
      },
      "outputs": [],
      "source": [
        "#@title Python imports\n",
        "import html\n",
        "import os\n",
        "import sys\n",
        "import tempfile\n",
        "from flax import nnx\n",
        "from google.colab import userdata\n",
        "from IPython.display import HTML\n",
        "import jax\n",
        "import kagglehub\n",
        "from matplotlib import colors, pyplot\n",
        "import numpy as np\n",
        "import sentencepiece as spm\n",
        "\n",
        "with tempfile.TemporaryDirectory() as tmp:\n",
        "  # Create a temporary directory and clone the `flax` repo.\n",
        "  # Then, append the `examples/gemma` folder to the path for loading the `gemma` modules.\n",
        "  ! git clone https://github.com/google/flax.git {tmp}/flax\n",
        "  sys.path.append(f\"{tmp}/flax/examples/gemma\")\n",
        "  import modules\n",
        "  import params as params_lib\n",
        "  import sampler as sampler_lib\n",
        "  import sow_lib\n",
        "  import transformer as transformer_lib\n",
        "\n",
        "  sys.path.pop()"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "8xUHvwdHH0O1"
      },
      "source": [
        "To use the Gemma model, you’ll need a Kaggle account and API key:\n",
        "\n",
        "1.  To create an account, visit Kaggle and click on ‘Register’.\n",
        "\n",
        "2.  If/once you have an account, you need to sign in, go to your ‘Settings’, and\n",
        "    under ‘API’ click on ‘Create New Token’ to generate and download your Kaggle\n",
        "    API key.\n",
        "\n",
        "3.  [Optional] In Google Colab, under ‘Secrets’ add your Kaggle username and API\n",
        "    key, storing the username as KAGGLE_USERNAME and the key as KAGGLE_KEY.\n",
        "\n",
        "4.  Request access to the model here:\n",
        "    https://www.kaggle.com/models/google/gemma-3"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lzbkuQfvwklN"
      },
      "outputs": [],
      "source": [
        "#@title Select a Gemma 3 model from Kaggle\n",
        "# Suppress noisy progress bar.\n",
        "%%capture captured_output --no-stdout\n",
        "\n",
        "VARIANT = 'gemma3-1b' # @param ['gemma3-1b', 'gemma3-1b-it', 'gemma3-4b', 'gemma3-4b-it', 'gemma3-12b', 'gemma3-12b-it', 'gemma3-27b', 'gemma3-27b-it'] {type:\"string\"}\n",
        "\n",
        "try:\n",
        "  os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
        "  os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
        "except userdata.SecretNotFoundError:\n",
        "  kagglehub.login()\n",
        "\n",
        "print(\"Downloading model ...\")\n",
        "weights_dir = kagglehub.model_download(f'google/gemma-3/flax/{VARIANT}')\n",
        "ckpt_path = f'{weights_dir}/{VARIANT}'\n",
        "vocab_path = f'{weights_dir}/tokenizer.model'\n",
        "print(\"Done.\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4fDQsC87ESaN"
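      },
      "source": [
        "## Start Setting up Your Model\n",
        "\n",
        "Prepare the sow configuration, which selects the intermediates you want to surface. The form below leaves everything disabled by default. For this tutorial, enable `embeddings`, `rs_after_ffw` and `rs_after_attention` to surface the residual stream, and set `mlp_hidden_topk=10` to surface the top activations in the MLP layers (also called feedforward layers). The attention visualization at the end of the notebook reads the top-k attention logits, so also set `attn_logits_topk` to a value greater than 0 if you want to run it."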
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "7InOzQtcESaN"
      },
      "outputs": [],
      "source": [
        "embeddings = False  # @param {\"type\":\"boolean\"}\n",
        "rs_after_ffw = False  # @param {\"type\":\"boolean\"}\n",
        "rs_after_attention = False  # @param {\"type\":\"boolean\"}\n",
        "mlp_hidden_topk = 0  # @param {\"type\":\"integer\"}\n",
        "attn_logits_topk = 0  # @param {\"type\":\"integer\"}\n",
        "\n",
        "sow_config = sow_lib.SowConfig(\n",
        "    embeddings=embeddings,\n",
        "    rs_after_ffw=rs_after_ffw,\n",
        "    rs_after_attention=rs_after_attention,\n",
        "    mlp_hidden_topk=mlp_hidden_topk,\n",
        "    attn_logits_topk=attn_logits_topk,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KaU-X3_jESaN"
      },
      "source": [
        "Now, build the model and required modules."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "bdstASGrESaN"
      },
      "outputs": [],
      "source": [
        "# Load parameters\n",
        "params = params_lib.load_and_format_params(ckpt_path)\n",
        "\n",
        "# Tokenizer\n",
        "vocab = spm.SentencePieceProcessor()\n",
        "vocab.Load(vocab_path)\n",
        "\n",
        "# Transformer Model\n",
        "transformer = transformer_lib.Transformer.from_params(\n",
        "    params, sow_config=sow_config\n",
        ")\n",
        "\n",
        "# Sampler\n",
        "sampler = sampler_lib.Sampler(\n",
        "    transformer=transformer,\n",
        "    vocab=vocab,\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "C1fLns-_ESaN"
      },
      "source": [
        "You're ready to start generating! Because we want to visualize the intermediate\n",
        "activations, we use a batch size of 1. We limit the model's output by setting `total_generation_steps = 1`, so that it is easier to investigate."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "qA0BhNQvESaN"
      },
      "outputs": [],
      "source": [
        "def run_model(prompt):\n",
        "  prompt_length = len(vocab.EncodeAsIds(prompt)) + 1\n",
        "  total_generation_steps = 1\n",
        "\n",
        "  out_data = sampler(\n",
        "      input_strings=[prompt],\n",
        "      echo=True,  # This returns the prompt as well.\n",
        "      total_generation_steps=total_generation_steps,\n",
        "  )\n",
        "  print(out_data.text[0])\n",
        "\n",
        "  out_length = np.count_nonzero(out_data.tokens)\n",
        "\n",
        "  return prompt_length, out_length, out_data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "i8UVzaiia1pR"
      },
      "outputs": [],
      "source": [
        "prompt_length, out_length, out_data = run_model(\"What is the capital of Switzerland? Answer:\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FwoJy8RyGA3g"
      },
      "source": [
        "## Investigate the model\n",
        "\n",
        "To investigate intermediate outputs, we map them into the text space by performing premature decoding. To be precise, we have to apply the final norm and the output (softmax) projection, then read off the 10 highest-scoring tokens. We create a helper function `premature_decode` for this."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Dkrm0w9ElT_E"
      },
      "outputs": [],
      "source": [
        "def format_token(string):\n",
        "  string = string.replace('▁', ' ')\n",
        "  string = string.replace('<', ' <')\n",
        "  string = html.escape(string)\n",
        "  return string\n",
        "\n",
        "\n",
        "def id_to_token(i):\n",
        "  return format_token(vocab.IdToPiece(i))\n",
        "\n",
        "\n",
        "def premature_decode(residual_stream):\n",
        "  residual_stream = transformer.final_norm(residual_stream)\n",
        "  logits = transformer.embedder.decode(residual_stream)\n",
        "  _, token_ids = jax.lax.top_k(logits, 10)\n",
        "  tokens = []\n",
        "  for top10_token_ids in token_ids:\n",
        "    tokens.append([id_to_token(int(id)) for id in top10_token_ids])\n",
        "    if top10_token_ids[0] == vocab.eos_id():\n",
        "      break\n",
        "  return tokens"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "InI1ODEoGnKW"
      },
      "source": [
        "To visualize the residual stream we create:\n",
        "\n",
        "*   a green row for the embedding layer\n",
        "*   a red row for each attention layer\n",
        "*   a blue row for each ffw layer\n",
        "\n",
        "You can see how the output evolves from bottom to top. Darker colors indicate\n",
        "a bigger change in the residual stream during that layer. Hover over a cell to\n",
        "see the remaining top-10 candidate tokens. Note that in the early layers the\n",
        "premature token is simply the previous token. This is because the residual\n",
        "stream was initialized with the embedding of the previous token. You might also\n",
        "find some switching between languages across the layers.\n",
        "For the example *What is the capital of Switzerland? Answer: Bern* you should see that the answer *Bern* is only decided in the very last feedforward layer."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "N7CMfuIwnR1B"
      },
      "outputs": [],
      "source": [
        "# @title Display residual stream\n",
        "\n",
        "\n",
        "def get_style(prob, color):\n",
        "  cmap = pyplot.get_cmap(color)\n",
        "  rgb = cmap(prob)[:3]\n",
        "  bg = colors.rgb2hex(rgb)\n",
        "  fg = 'black' if sum(rgb) > 1.5 else 'white'\n",
        "  return 'color: {};background-color: {};'.format(fg, bg)\n",
        "\n",
        "\n",
        "def cosine_similarity(a, b):\n",
        "  return np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
        "\n",
        "\n",
        "def get_premature_tokens(\n",
        "    residual_stream, previous_rs, title, color, length=out_length\n",
        "):\n",
        "  batch_size, _, _ = residual_stream.shape\n",
        "  assert batch_size == 1\n",
        "  previous_rs = previous_rs[0, 1:length, :]\n",
        "  premature_rs = residual_stream[0, 1:length, :]\n",
        "  premature_tokens = premature_decode(premature_rs)\n",
        "  tds = ''\n",
        "  for i, top10_premature_tokens in enumerate(premature_tokens):\n",
        "    value = 1 - cosine_similarity(premature_rs[i], previous_rs[i])\n",
        "    top_premature_token, *others = top10_premature_tokens\n",
        "    title_for_token = title + '\\n' + '/'.join(others)\n",
        "    tds += \"<td title='{}' style='{}'>{}</td>\".format(\n",
        "        title_for_token, get_style(value * 70, color), top_premature_token\n",
        "    )\n",
        "  return tds\n",
        "\n",
        "\n",
        "def print_premature_layers(intermediates):\n",
        "  trs = []\n",
        "  previous_rs = np.ones_like(intermediates.embeddings)\n",
        "  tds = get_premature_tokens(\n",
        "      intermediates.embeddings,\n",
        "      previous_rs,\n",
        "      'After Embedding:',\n",
        "      color='Greens',\n",
        "  )\n",
        "  previous_rs = intermediates.embeddings\n",
        "  trs.append(f'<tr>{tds}</tr>')\n",
        "  for layer_id, layer in enumerate(intermediates.layers):\n",
        "    tds = get_premature_tokens(\n",
        "        layer.rs_after_attention,\n",
        "        previous_rs,\n",
        "        f'After Attention {layer_id}:',\n",
        "        color='Reds',\n",
        "    )\n",
        "    trs.append(f'<tr>{tds}</tr>')\n",
        "    previous_rs = layer.rs_after_attention\n",
        "    tds = get_premature_tokens(\n",
        "        layer.rs_after_ffw,\n",
        "        previous_rs,\n",
        "        f'After FFW {layer_id}:',\n",
        "        color='Blues',\n",
        "    )\n",
        "    trs.append(f'<tr>{tds}</tr>')\n",
        "    previous_rs = layer.rs_after_ffw\n",
        "  tds = get_premature_tokens(\n",
        "      intermediates.embeddings[:, 1:, :],\n",
        "      previous_rs,\n",
        "      'Forced:',\n",
        "      color='Greens',\n",
        "      length=prompt_length,\n",
        "  )\n",
        "  trs.append(f'<tr>{tds}</tr>')\n",
        "  trs.reverse()\n",
        "  html_string = f'<table>{\"\".join(trs)}</table>'\n",
        "  return HTML(html_string)\n",
        "\n",
        "\n",
        "print_premature_layers(out_data.intermediates)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "c9hYsqK-CTVW"
      },
      "source": [
        "The next cell visualizes the top activated neurons in a given layer. You can hover over the colored blocks to get the neuron id and value. For the last feedforward layer (set `layer` to _25_) you should see that the top activated neurons are *1937* and *4422*."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "QQg-V4yZiOlu"
      },
      "outputs": [],
      "source": [
        "# @title Activation in Feedforward Layers\n",
        "layer = 0  # @param {\"type\":\"integer\"}\n",
        "\n",
        "\n",
        "def get_activations_line(step, token_id, intermediates):\n",
        "  tr = '<td>{}</td>'.format(id_to_token(token_id))\n",
        "  values = intermediates.layers[layer].mlp_hidden_topk_values[0, step, :]\n",
        "  indices = intermediates.layers[layer].mlp_hidden_topk_indices[0, step, :]\n",
        "  mouseover_texts = [f'Layer: {layer}'] * 70\n",
        "  colors = [0.0] * 70\n",
        "  for value, neuron in zip(values, indices):\n",
        "    neuron = int(neuron)\n",
        "    mouseover_texts[neuron // 100] += f'\\nNeuron: {neuron}, Value: {value:3.2f}'\n",
        "    colors[neuron // 100] += value / values[0]\n",
        "  for mouseover_text, color in zip(mouseover_texts, colors):\n",
        "    style = get_style(color, 'Blues')\n",
        "    tr += f\"<td title='{mouseover_text}' style='{style}'>&nbsp;&nbsp;</td>\"\n",
        "\n",
        "  return '<tr>{}</tr>'.format(tr)\n",
        "\n",
        "\n",
        "def print_activations(tokens, intermediates):\n",
        "  html_string = ''\n",
        "  for step, token in enumerate(tokens):\n",
        "    html_string += get_activations_line(step, token, intermediates)\n",
        "    if token == vocab.eos_id():\n",
        "      break\n",
        "  html_string = f'<table>{html_string}</table>'\n",
        "  return HTML(html_string)\n",
        "\n",
        "\n",
        "print_activations(out_data.tokens[0], out_data.intermediates)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "30A0KOHTCiVB"
      },
      "source": [
        "After identifying a neuron of interest, we can deactivate or boost it. You can play around with the bias values to achieve different behaviours.\n",
        "1.  **Deactivate** neuron 1937 in layer 25 and issue the same prompt again.\n",
        "2.  **Boost** neuron 1937 in layer 25 and your model should respond with *Switzerland* significantly more often.\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "PNGZjQS_du_8"
      },
      "outputs": [],
      "source": [
        "#@title Mask/Boost a single neuron\n",
        "layer = 0  # @param {\"type\":\"integer\"}\n",
        "neuron = 0  # @param {\"type\":\"integer\"}\n",
        "operation = \"none\"  # @param [\"none\",\"deactivate\",\"boost\"]\n",
        "\n",
        "gate_bias = np.zeros(6912)\n",
        "if operation == \"none\":\n",
        "  gate_bias[neuron] = 0.0\n",
        "elif operation == \"boost\":\n",
        "  gate_bias[neuron] = 1.0\n",
        "elif operation == \"deactivate\":\n",
        "  gate_bias[neuron] = -1000.0\n",
        "transformer.layers[layer].mlp.gate_proj.use_bias = True\n",
        "transformer.layers[layer].mlp.gate_proj.bias = nnx.Variable(gate_bias)\n",
        "print(f\"Gate bias of neuron {neuron} in layer {layer} set to {gate_bias[neuron]}\")\n",
        "\n",
        "if operation == \"boost\":\n",
        "  up_bias = np.zeros(6912)\n",
        "  up_bias[neuron] = 20.0\n",
        "  transformer.layers[layer].mlp.up_proj.use_bias = True\n",
        "  transformer.layers[layer].mlp.up_proj.bias = nnx.Variable(up_bias)\n",
        "  print(f\"Up proj bias of neuron {neuron} in layer {layer} set to {up_bias[neuron]}\")\n",
        "\n",
        "sampler = sampler_lib.Sampler(\n",
        "    transformer=transformer,\n",
        "    vocab=vocab,\n",
        ")\n"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "CNWJ-8YzDCxw"
      },
      "source": [
        "We can also apply the same decoding steps as in `premature_decode` to the weight vector of a single neuron to investigate its effect. Check neuron *1937* of layer *25* to see why deactivating or boosting it changed the model's behaviour."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "6zK3SVRHFwDg"
      },
      "outputs": [],
      "source": [
        "#@title Print top k associated tokens.\n",
        "layer = 0  # @param {\"type\":\"integer\"}\n",
        "neuron = 0  # @param {\"type\":\"integer\"}\n",
        "variable = \"down_proj\"  # @param [\"up_proj\",\"gate_proj\",\"down_proj\"]\n",
        "\n",
        "if variable == \"up_proj\":\n",
        "  embedding = transformer.layers[layer].mlp.up_proj.kernel[:, neuron]\n",
        "elif variable == \"gate_proj\":\n",
        "  embedding = transformer.layers[layer].mlp.gate_proj.kernel[:, neuron]\n",
        "else:\n",
        "  embedding = transformer.layers[layer].mlp.down_proj.kernel[neuron, :]\n",
        "\n",
        "normalized_stream = transformer.final_norm(embedding)\n",
        "logits = transformer.embedder.decode(normalized_stream)\n",
        "_, token_ids = jax.lax.top_k(logits, 20)\n",
        "for i, token_id in enumerate(token_ids):\n",
        "  print(f\"{i}: {id_to_token(int(token_id))}\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "4XsgjqDpH3Tv"
      },
      "source": [
        "We can also take a look at the attention mechanism. For the purpose of this\n",
        "tutorial we take the last layer and average over all heads."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "cellView": "form",
        "id": "2Cjc8YMM-fov"
      },
      "outputs": [],
      "source": [
        "#@title Attention visualization\n",
        "html_header = \"\"\"\n",
        "<style>\n",
        "  span:hover {\n",
        "    color:white !important;\n",
        "    background-color:#00670d !important;\n",
        "  }\n",
        "</style>\n",
        "<script>\n",
        "  const componentToHex = (c) => {\n",
        "    const hex = c.toString(16);\n",
        "    return hex.length == 1 ? \"0\" + hex : hex;\n",
        "  }\n",
        "  function hueToRgb(t1, t2, hue) {\n",
        "    if (hue < 0) hue += 6;\n",
        "    if (hue >= 6) hue -= 6;\n",
        "    if (hue < 1) return (t2 - t1) * hue + t1;\n",
        "    else if(hue < 3) return t2;\n",
        "    else if(hue < 4) return (t2 - t1) * (4 - hue) + t1;\n",
        "    else return t1;\n",
        "  }\n",
        "  function hslToRgb(hue, sat, light) {\n",
        "    var t1, t2, r, g, b;\n",
        "    hue = hue / 60;\n",
        "    if ( light <= 0.5 ) {\n",
        "      t2 = light * (sat + 1);\n",
        "    } else {\n",
        "      t2 = light + sat - (light * sat);\n",
        "    }\n",
        "    t1 = light * 2 - t2;\n",
        "    r = hueToRgb(t1, t2, hue + 2) * 255;\n",
        "    g = hueToRgb(t1, t2, hue) * 255;\n",
        "    b = hueToRgb(t1, t2, hue - 2) * 255;\n",
        "    r = componentToHex(Math.floor(r));\n",
        "    g = componentToHex(Math.floor(g));\n",
        "    b = componentToHex(Math.floor(b));\n",
        "    console.log(r, g, b);\n",
        "    return `#${r}${g}${b}`;\n",
        "  }\n",
        "  function getStyle(value, h=0) {\n",
        "    value = Math.min(value, 1.0)\n",
        "    fg = value < 0.7 ? \"black\" : \"white\";\n",
        "    bg = hslToRgb(h, 1.0, 1-(value*0.8));\n",
        "    return `color: ${fg}; background-color: ${bg}`;\n",
        "  }\n",
        "  function showAttention(self, atten_probs) {\n",
        "    for (i = 0; i < atten_probs.length; i++) {\n",
        "      atten_prob = atten_probs[i]\n",
        "      document.getElementById('token_' + i).title = `Value: ${atten_prob}`;\n",
        "      document.getElementById('token_' + i).style.cssText = getStyle(atten_prob);\n",
        "    }\n",
        "    self.style.cssText = 'color: white; background-color: #00670d';\n",
        "  }\n",
        "</script>\"\"\"\n",
        "\n",
        "\n",
        "def print_attention(token_ids, intermediates):\n",
        "  attention_html = '<p>'\n",
        "  output_html = '<p>'\n",
        "  for i, token_id in enumerate(token_ids):\n",
        "    # Get topk attention values of current token and unsparsify.\n",
        "    all_atten_probs = []\n",
        "    # Last layer only.\n",
        "    for layer in intermediates.layers[-1:]:\n",
        "      # Average over all heads.\n",
        "      _, _, num_heads, _ = layer.attn_logits_topk_values.shape\n",
        "      for head in range(num_heads):\n",
        "        atten_logits = modules.K_MASK * np.ones(out_length)\n",
        "        for value, index in zip(\n",
        "            layer.attn_logits_topk_values[0, i, head, :],\n",
        "            layer.attn_logits_topk_indices[0, i, head, :],\n",
        "        ):\n",
        "          atten_logits[index] = value\n",
        "        # The models tends to attend a lot towards the BOS token. Mask this\n",
        "        # to have a more meaningful visualization.\n",
        "        atten_logits[0] = modules.K_MASK\n",
        "        # Note that this softmax is an approximation as we have topks only.\n",
        "        atten_probs = np.exp(atten_logits) / (\n",
        "            sum(np.exp(atten_logits)) + 0.000001\n",
        "        )\n",
        "        all_atten_probs.append(atten_probs)\n",
        "    avg_atten_probs = np.sum(all_atten_probs, axis=0)\n",
        "    avg_atten_probs /= np.max(avg_atten_probs) + 0.000001\n",
        "\n",
        "    token = id_to_token(token_id)\n",
        "    atten_probs_string = (\n",
        "        np.array2string(avg_atten_probs, separator=',')\n",
        "        .replace('\\n', '')\n",
        "        .replace(' ', '')\n",
        "    )\n",
        "    onclick = f'showAttention(this, {atten_probs_string})'\n",
        "    attention_html += \"<span id='token_{}' onclick='{}'>{}</span>\".format(\n",
        "        i, onclick, token\n",
        "    )\n",
        "\n",
        "    if token_id == vocab.eos_id():\n",
        "      break\n",
        "\n",
        "  attention_html += '</p>'\n",
        "  output_html += '</p>'\n",
        "\n",
        "  return HTML(html_header + attention_html)\n",
        "\n",
        "\n",
        "print_attention(out_data.tokens[0], out_data.intermediates)"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "name": "[Gemma_3]Activation_Hacking.ipynb",
      "toc_visible": true
    },
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
